"""
Phase 4: Fiscal Dominance Regime Classification
=================================================
Construct fiscal stress indicators and test whether demographics predict
transitions into fiscal dominance regimes.

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
Output: fiscal_dominance/output/tables/phase4_*.csv
        fiscal_dominance/data/processed/fiscal_regimes.csv (regime indicators)
"""

import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import stats

# Project root. Defaults to the original hard-coded location; set the
# DEMOGRAPHICS_PROJECT_DIR environment variable to run from another checkout.
PROJECT_DIR = Path(
    os.environ.get("DEMOGRAPHICS_PROJECT_DIR", "/mnt/c/demographics_capital_flows")
)
FD_DIR = PROJECT_DIR / "fiscal_dominance"
PROCESSED_DIR = FD_DIR / "data" / "processed"
TABLE_DIR = FD_DIR / "output" / "tables"

# PanelGLS lives in the sibling "multilateral" project's src directory, which
# is put on the import path so it can be imported as a top-level module.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS  # noqa: E402  (import must follow the path tweak)


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS regression, report it, and tabulate the result.

    Parameters
    ----------
    y, X : dependent variable and regressor matrix, passed straight to
        ``PanelGLS.fit``.
    entity_ids, time_ids : panel identifiers, passed straight to
        ``PanelGLS.fit``.
    feature_names : labels for the columns of ``X``.
    label : model name printed in the banner and stored in the ``model``
        column of the returned table.

    Returns
    -------
    (model, result_df)
        The fitted PanelGLS instance and its ``to_dataframe`` output with the
        model-level columns model, n_obs, n_countries, r_squared and rho
        appended (same value on every row).
    """
    fitted = PanelGLS()
    fitted.fit(y, X, entity_ids, time_ids)

    rule = '=' * 70
    print(f"\n{rule}")
    print(f"  {label}")
    print(f"  N={fitted.n_obs:,}, {fitted.n_countries} countries, "
          f"R²={fitted.r_squared:.4f}, rho={fitted.rho:.3f}")
    print(rule)
    fitted.summary(feature_names=feature_names)

    table = fitted.to_dataframe(feature_names=feature_names)
    # Broadcast model-level statistics onto every coefficient row so the
    # tables from several models can be stacked and still be told apart.
    metadata = {
        'model': label,
        'n_obs': fitted.n_obs,
        'n_countries': fitted.n_countries,
        'r_squared': fitted.r_squared,
        'rho': fitted.rho,
    }
    for column, value in metadata.items():
        table[column] = value
    return fitted, table


def _fd_indicator(df):
    """Return the binary fiscal-dominance flag as a float Series.

    FD = 1.0 when r-g > 0 AND the primary balance is negative AND
    govt_debt_gdp > 60, else 0.0; NaN wherever any of the three inputs is
    missing (NaN comparisons would otherwise silently yield 0.0).
    """
    inputs = df[['r_minus_g', 'primary_bal_gdp', 'govt_debt_gdp']]
    flag = (
        (inputs['r_minus_g'] > 0)
        & (inputs['primary_bal_gdp'] < 0)
        & (inputs['govt_debt_gdp'] > 60)
    ).astype(float)
    return flag.mask(inputs.isna().any(axis=1))


def _fiscal_stress_index(df):
    """Return the equal-weight mean of z-scored stress components, or None.

    Components: r-g, debt/GDP, and the negated primary balance (a deficit
    means more stress). Returns None if any component column is absent, and
    NaN on rows where any component is missing.

    NOTE: each component is standardized over all of its own non-missing
    rows, not only over the rows where all three are observed.
    """
    spec = [('r_minus_g', False), ('govt_debt_gdp', False), ('primary_bal_gdp', True)]
    columns = [var for var, _ in spec]
    if any(var not in df.columns for var in columns):
        return None
    z_scores = []
    for var, flip in spec:
        series = -df[var] if flip else df[var]
        z_scores.append((series - series.mean()) / series.std())
    stress = sum(z_scores) / len(z_scores)
    return stress.mask(df[columns].isna().any(axis=1))


def _fit_if_enough(df, dep_var, regressors, label, results, min_obs=200):
    """Fit one specification via fit_and_report when the data support it.

    Uses the complete-case rows of ``df`` for ``dep_var`` + ``regressors``,
    appends the coefficient table to ``results`` and returns the fitted
    model. Returns None (fitting nothing) when a required column is absent
    or fewer than ``min_obs`` complete-case rows remain.
    """
    required = [dep_var] + list(regressors)
    missing = [c for c in required if c not in df.columns]
    if missing:
        print(f"\n  Skipping {label}: missing columns {missing}")
        return None
    est = df.dropna(subset=required)
    if len(est) < min_obs:
        print(f"\n  Skipping {label}: {len(est):,} complete obs < {min_obs}")
        return None
    model, result_df = fit_and_report(
        est[dep_var].values, est[regressors].values,
        est['iso3'].values, est['year'].values,
        regressors, label,
    )
    results.append(result_df)
    return model


def main():
    """Run the Phase 4 fiscal-dominance regime analysis end to end.

    Reads ``fiscal_panel.csv`` from PROCESSED_DIR, builds the binary
    fiscal-dominance flag and the continuous fiscal stress index, fits the
    PanelGLS specifications (Models 1-5, each skipped when its columns are
    missing or it has fewer than 200 complete observations), writes
    ``fiscal_regimes.csv`` to PROCESSED_DIR and the phase-4 tables to
    TABLE_DIR.

    Returns:
        DataFrame stacking every fitted model's coefficient table, or an
        empty DataFrame when no model was fitted.
    """
    print("=" * 70)
    print("PHASE 4: Fiscal Dominance Regime Classification")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "fiscal_panel.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")
    # to_csv does not create parent directories.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    all_results = []
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in ['kaopen', 'log_rel_opw', 'nfa_gdp_lag'] if c in df.columns]

    # =================================================================
    # 1. Binary Fiscal Dominance Indicator
    #    FD = 1 if: r-g > 0 AND primary deficit AND debt > 60%
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  Constructing Fiscal Dominance Indicators")
    print("=" * 70)

    df['fd_binary'] = _fd_indicator(df)
    n_fd = df['fd_binary'].sum()
    n_total = df['fd_binary'].notna().sum()
    fd_pct = 100 * n_fd / n_total if n_total else float('nan')
    print(f"  Binary FD indicator: {n_fd:.0f}/{n_total} obs ({fd_pct:.1f}%)")

    # Country-level prevalence (mean OADR only when the column exists).
    prevalence_spec = {
        'fd_share': ('fd_binary', 'mean'),
        'n_years': ('fd_binary', 'count'),
        'mean_debt': ('govt_debt_gdp', 'mean'),
    }
    if 'old_dep' in df.columns:
        prevalence_spec['mean_oadr'] = ('old_dep', 'mean')
    country_fd = (
        df.groupby('iso3')
        .agg(**prevalence_spec)
        .sort_values('fd_share', ascending=False)
    )
    print("\n  Top-15 countries by FD prevalence:")
    print(country_fd.head(15).to_string(float_format='%.3f'))
    country_fd.to_csv(TABLE_DIR / "phase4_country_fd_prevalence.csv")

    # =================================================================
    # 2. Continuous Fiscal Stress Index
    #    Standardized composite of: r-g, debt/GDP, primary deficit
    # =================================================================
    stress = _fiscal_stress_index(df)
    if stress is not None:
        df['fiscal_stress'] = stress
        print(f"\n  Fiscal stress index: {df['fiscal_stress'].notna().sum():,} obs")
        print(f"    Mean={df['fiscal_stress'].mean():.3f}, Std={df['fiscal_stress'].std():.3f}")
        print(f"    Range: [{df['fiscal_stress'].min():.2f}, {df['fiscal_stress'].max():.2f}]")

    # =================================================================
    # 3. Z -> Fiscal Stress Index (continuous), plain and with KAOPEN
    #    interactions (interaction columns come pre-built in the panel)
    # =================================================================
    if 'fiscal_stress' in df.columns:
        _fit_if_enough(
            df, 'fiscal_stress', demo_vars + controls,
            "Model 1: Z -> Fiscal Stress Index", all_results,
        )
        int_vars = [v for v in ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
                    if v in df.columns]
        if int_vars:
            _fit_if_enough(
                df, 'fiscal_stress', demo_vars + controls + int_vars,
                "Model 2: Z -> Fiscal Stress + KAOPEN interactions", all_results,
            )

    # =================================================================
    # 4. Transition Probability: P(FD=1) = f(Z, debt_lag, controls)
    #    Linear probability model via PanelGLS (no probit in the toolkit),
    #    so coefficients read directly as marginal effects.
    # =================================================================
    vars_3 = demo_vars + ['debt_lag'] + controls
    m3 = _fit_if_enough(
        df, 'fd_binary', vars_3,
        "Model 3: LPM P(FD=1) = f(Z, debt_lag, controls)", all_results,
    )
    if m3 is not None:
        print("\n  >>> TRANSITION PROBABILITY:")
        for zv in demo_vars:
            idx = vars_3.index(zv)
            print(f"    {zv}: marginal effect = {m3.beta[idx]:.4f} (p={m3.pvalues[idx]:.4f})")

    # OADR quadratic specification
    if 'old_dep' in df.columns:
        df['old_dep_sq'] = df['old_dep'] ** 2
        vars_3b = ['old_dep', 'old_dep_sq', 'debt_lag'] + controls
        m3b = _fit_if_enough(
            df, 'fd_binary', vars_3b,
            "Model 3b: LPM P(FD=1) with OADR Quadratic", all_results,
        )
        if m3b is not None:
            b1 = m3b.beta[vars_3b.index('old_dep')]
            b2 = m3b.beta[vars_3b.index('old_dep_sq')]
            if b2 != 0:
                # Vertex of b1*x + b2*x**2.
                turning = -b1 / (2 * b2)
                print(f"\n  >>> OADR turning point for FD probability: {turning:.1f}%")

    # =================================================================
    # 5. FD probability with Z x KAOPEN interactions
    #    (the fiscal-stress counterpart is Model 2 above)
    # =================================================================
    if 'kaopen' in df.columns and all(z in df.columns for z in demo_vars):
        interaction_vars = ['Z1_x_kaopen_fd', 'Z2_x_kaopen_fd', 'Z3_x_kaopen_fd']
        for zv, name in zip(demo_vars, interaction_vars):
            df[name] = df[zv] * df['kaopen']
        vars_4 = (demo_vars + ['debt_lag', 'kaopen'] + interaction_vars
                  + [c for c in controls if c != 'kaopen'])
        _fit_if_enough(
            df, 'fd_binary', vars_4,
            "Model 4: LPM P(FD) + KAOPEN interaction", all_results,
        )

    # =================================================================
    # 6. CA -> Fiscal Stress channel (mediation from Project 1)
    # =================================================================
    if 'ca_gdp' in df.columns and 'fiscal_stress' in df.columns:
        _fit_if_enough(
            df, 'fiscal_stress', ['ca_gdp'] + demo_vars + controls,
            "Model 5: Fiscal Stress = f(CA/GDP, Z, controls)", all_results,
        )

    # =================================================================
    # Save regime indicators back to panel
    # =================================================================
    regime_cols = [c for c in ['iso3', 'year', 'fd_binary', 'fiscal_stress']
                   if c in df.columns]
    df[regime_cols].to_csv(PROCESSED_DIR / "fiscal_regimes.csv", index=False)
    print(f"\nSaved regime indicators: {PROCESSED_DIR / 'fiscal_regimes.csv'}")

    # =================================================================
    # Save all results
    # =================================================================
    if not all_results:
        return pd.DataFrame()
    results_df = pd.concat(all_results, ignore_index=True)
    results_df.to_csv(TABLE_DIR / "phase4_regime_results.csv", index=False)
    print(f"\n{'=' * 70}")
    print(f"Saved: {TABLE_DIR / 'phase4_regime_results.csv'}")
    print(f"  {len(results_df)} rows across {results_df['model'].nunique()} models")
    return results_df


if __name__ == "__main__":
    # Script entry point; the returned results table stays bound at module
    # level (e.g. for inspection in an interactive `python -i` session).
    results = main()
